{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# from thuglm.modeling_chatglm import ChatGLMForConditionalGeneration\n",
    "from transformers import AutoConfig, AutoTokenizer, AutoModel\n",
    "import os \n",
    "import torch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "model_name_or_path = \"D:\\\\训练代码\\\\ptuning\\\\output\\\\adgen-chatglm-6b-pt-$PRE_SEQ_LEN-$LR\\\\checkpoint-50\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Explicitly passing a `revision` is encouraged when loading a configuration with custom code to ensure no malicious code has been contributed in a newer revision.\n"
     ]
    }
   ],
   "source": [
    "config = AutoConfig.from_pretrained(model_name_or_path, trust_remote_code=True)\n",
    "config.pre_seq_len = 8\n",
    "config.prefix_projection = False# model_args.prefix_projection"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Explicitly passing a `revision` is encouraged when loading a model with custom code to ensure no malicious code has been contributed in a newer revision.\n"
     ]
    }
   ],
   "source": [
    "tokenizer = AutoTokenizer.from_pretrained(\"THUDM/chatglm-6b\", trust_remote_code=True)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Explicitly passing a `revision` is encouraged when loading a model with custom code to ensure no malicious code has been contributed in a newer revision.\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "0dc03da0da0b4b49967a93810acfdfae",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Loading checkpoint shards:   0%|          | 0/8 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of ChatGLMForConditionalGeneration were not initialized from the model checkpoint at THUDM/chatglm-6b and are newly initialized: ['transformer.prefix_encoder.embedding.weight']\n",
      "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
     ]
    }
   ],
   "source": [
    "model = AutoModel.from_pretrained(\"THUDM/chatglm-6b\", config=config, trust_remote_code=True)\n",
    "prefix_state_dict = torch.load(os.path.join(model_name_or_path, \"pytorch_model.bin\"))\n",
    "new_prefix_state_dict = {}\n",
    "for k, v in prefix_state_dict.items():\n",
    "    if k.startswith(\"transformer.prefix_encoder.\"):\n",
    "        new_prefix_state_dict[k[len(\"transformer.prefix_encoder.\"):]] = v\n",
    "model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)\n",
    "model = model.eval()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "PrefixEncoder(\n",
       "  (embedding): Embedding(8, 229376)\n",
       ")"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model = model.half()\n",
    "model.transformer.prefix_encoder.float()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = model.cuda()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "The dtype of attention mask (torch.int64) is not bool\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "你好，我是ChatGLM。很高兴能和你聊天。\n"
     ]
    }
   ],
   "source": [
    "response, history = model.chat(tokenizer, \"你好\", history=[])\n",
    "print(response)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "mynet",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.7"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
